import jax.numpy as jnp

def invertibility(x, x_approx):
    """Return the mean squared error between ``x`` and its reconstruction.

    Args:
        x: Reference array.
        x_approx: Reconstruction of ``x``; must broadcast against ``x``.

    Returns:
        Scalar mean of the element-wise squared differences.
    """
    residual = x - x_approx
    return jnp.mean(residual ** 2)

def low_dimensionality(z, n_dim):
    """Return the mean L1 magnitude of the latent coordinates outside the kept subspace.

    The last ``n_dim`` coordinates (along axis 1) of ``z`` are left
    unpenalized. The leading ``z.shape[1] - n_dim`` coordinates are penalized
    by their mean absolute value. That value is averaged over those
    coordinates per sample, then averaged over samples.

    Args:
        z: Array of latent codes indexed as ``z[sample, coordinate]``.
        n_dim: Number of trailing coordinates to exclude from the penalty.
            Must satisfy ``0 <= n_dim <= z.shape[1]``.

    Returns:
        Scalar penalty. It is exactly zero when ``n_dim == z.shape[1]``,
        because no coordinates are penalized.

    Raises:
        ValueError: If ``n_dim`` is outside ``[0, z.shape[1]]``.
    """
    n_total = z.shape[1]
    if not 0 <= n_dim <= n_total:
        # A negative cut index would silently wrap around (``z[:, :-k]``)
        # and penalize the wrong coordinates.
        raise ValueError(f"n_dim must be in [0, {n_total}], got {n_dim}")
    idx = n_total - n_dim
    if idx == 0:
        # The mean over an empty slice is NaN, which would poison the loss
        # and its gradients. An empty penalty is zero.
        return jnp.zeros((), dtype=jnp.result_type(float, z.dtype))
    l1_norm = jnp.mean(jnp.abs(z[:, :idx]), axis=1)
    return jnp.mean(l1_norm)

def isometry(z, dij, dij_max):
    """Return the normalized mean squared mismatch between target and latent pairwise distances.

    Args:
        z: Latent codes. Pairwise Euclidean distances are taken over the
            last axis.
        dij: Target pairwise distances. Presumably an (N, N) matrix matching
            the rows of ``z``; TODO confirm against callers.
        dij_max: Normalization scale applied to the distance residuals.

    Returns:
        Scalar mean of ``((dij - d_z) / dij_max) ** 2``, where ``d_z`` holds
        the pairwise distances between the rows of ``z``.
    """
    # Distances are symmetric, so the subtraction order does not matter.
    pair_diffs = z[:, None] - z[None]
    squared = (pair_diffs ** 2).sum(axis=-1)
    # The small epsilon keeps sqrt differentiable where two points coincide,
    # e.g. on the diagonal.
    latent_dist = jnp.sqrt(squared + 1e-8)
    scaled_residual = (dij - latent_dist) / dij_max
    return jnp.mean(scaled_residual ** 2)


